package com.yifeng.repo.toolkit.elasticsearch;

import com.fasterxml.jackson.databind.JsonNode;
import com.yifeng.repo.base.utils.converter.JacksonHelper;
import com.yifeng.repo.toolkit.elasticsearch.configure.ElasticsearchProperties;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.search.*;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.core.CountRequest;
import org.elasticsearch.client.core.CountResponse;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.search.aggregations.metrics.ParsedCardinality;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.search.sort.SortOrder;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Created by daibing on 2021/5/18.
 */
@Slf4j
public class ElasticSearchQueryWorker {
    private final ElasticsearchProperties properties;
    private final RestHighLevelClient esClient;

    public ElasticSearchQueryWorker(ElasticSearchClient client) {
        // 默认仅支持单集群
        this.esClient = client.getClient();
        this.properties = client.getProperties();
    }

    public Map<String, Object> select(String indexName, String id) {
        GetRequest request = new GetRequest(indexName, id);
        try {
            GetResponse response = esClient.get(request, RequestOptions.DEFAULT);
            Map<String, Object> map = new HashMap<>();
            map.put("_id", response.getId());
            map.put("_type", response.getType());
            map.putAll(response.getSourceAsMap());
            return map;
        } catch (Throwable t) {
            log.warn("select failed: indexName={}, id={}, error=", indexName, id, t);
            throw new RuntimeException(String.format("select failed: indexName=%s, id=%s, error=%s", indexName, id, t));
        }
    }

    public List<Map<String, Object>> select(String indexName, FetchSourceContext source, QueryBuilder condition, Map<String, SortOrder> sorts, int pageNo, int pageSize) {
        // 1. 准备查询: 查询字段、查询条件、排序、分页
        SearchSourceBuilder builder = new SearchSourceBuilder();
        builder.fetchSource(source);
        builder.query(condition);
        for (Map.Entry<String, SortOrder> entry : sorts.entrySet()) {
            builder.sort(entry.getKey(), entry.getValue());
        }
        builder.from(pageNo).size(pageSize);

        // 2. 执行查询：解析结果
        SearchRequest request = new SearchRequest(new String[]{indexName}, builder);
        if (properties.getRequestCache() != null) {
            request.requestCache(properties.getRequestCache());
        }
        try {
            SearchResponse response = esClient.search(request, RequestOptions.DEFAULT);
            List<Map<String, Object>> list = new ArrayList<>();
            for (SearchHit hit : response.getHits().getHits()) {
                Map<String, Object> map = new HashMap<>();
                map.put("_id", hit.getId());
                if (hit.getSourceAsMap() != null) {
                    map.putAll(hit.getSourceAsMap());
                }
                list.add(map);
            }
            log.info("select ok: indexName={}, pageNo={}, pageSize={}, count={}, timeout={}, tookMillis={}, totalHit={}",
                    indexName, pageNo, pageSize, list.size(), response.isTimedOut(), response.getTook().millis(), response.getHits().getTotalHits().value);
            return list;
        } catch (Throwable t) {
            log.warn("select failed: indexName={}, pageNo={}, pageSize={}, error=", indexName, pageNo, pageSize, t);
            throw new RuntimeException(String.format("select failed: indexName=%s, pageNo=%s, pageSize=%s, error=%s", indexName, pageNo, pageSize, t));
        }
    }

    public long count(String indexName, QueryBuilder condition) {
        try {
            CountRequest countRequest = new CountRequest(new String[]{indexName}, condition);
            CountResponse response = esClient.count(countRequest, RequestOptions.DEFAULT);
            log.info("count ok: indexName={}, count={}, totalShards={}, successfulShards={}, failedShards={}, skippedShards={}",
                    indexName, response.getCount(), response.getTotalShards(), response.getSuccessfulShards(), response.getFailedShards(), response.getSkippedShards());
            return response.getCount();
        } catch (Throwable t) {
            log.warn("count failed: indexName={}, error=", indexName, t);
            throw new RuntimeException(String.format("count failed: indexName=%s, error=%s", indexName, t));
        }
    }

    public Map<String, Object> aggsOne(String indexName, QueryBuilder condition, AggregationBuilder aggregation) {
        return this.aggsOne(indexName, condition, aggregation, 5);
    }

    public Map<String, Object> aggsOne(String indexName, QueryBuilder condition, AggregationBuilder aggregation, int maxConcurrentShardRequests) {
        // 1. 准备查询: 查询条件、聚合方法
        SearchSourceBuilder builder = new SearchSourceBuilder();
        builder.query(condition);
        builder.size(0);
        builder.aggregation(aggregation);

        // 2. 执行查询：解析结果
        SearchRequest request = new SearchRequest(new String[]{indexName}, builder);
        request.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
        if (properties.getRequestCache() != null) {
            request.requestCache(properties.getRequestCache());
        }
        try {
            SearchResponse response = esClient.search(request, RequestOptions.DEFAULT);
            Aggregation result = response.getAggregations().asMap().get(aggregation.getName());
            Map<String, Object> map = new HashMap<>();
            ParsedCardinality parsedCardinality = (ParsedCardinality) result;
            map.put("name", parsedCardinality.getName());
            map.put("value", parsedCardinality.getValue());
            log.info("aggsOne ok: indexName={}, timeout={}, tookMillis={}, totalHit={}",
                    indexName, response.isTimedOut(), response.getTook().millis(), response.getHits().getTotalHits().value);
            return map;
        } catch (Throwable t) {
            log.warn("aggsOne failed: indexName={}, error=", indexName, t);
            throw new RuntimeException(String.format("aggsOne failed: indexName=%s, error=%s", indexName, t));
        }
    }

    public List<Map<String, Object>> aggs(String indexName, QueryBuilder condition, AggregationBuilder aggregation) {
        return this.aggs(indexName, condition, aggregation, 5);
    }

    public List<Map<String, Object>> aggs(String indexName, QueryBuilder condition, AggregationBuilder aggregation, int maxConcurrentShardRequests) {
        // 1. 准备查询: 查询条件、聚合方法
        SearchSourceBuilder builder = new SearchSourceBuilder();
        builder.query(condition);
        builder.size(0);
        builder.aggregation(aggregation);

        // 2. 执行查询：解析结果
        SearchRequest request = new SearchRequest(new String[]{indexName}, builder);
        request.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
        if (properties.getRequestCache() != null) {
            request.requestCache(properties.getRequestCache());
        }
        try {
            SearchResponse response = esClient.search(request, RequestOptions.DEFAULT);
            Aggregation result = response.getAggregations().asMap().get(aggregation.getName());
            List<Map<String, Object>> list = new ArrayList<>();
            if (result instanceof MultiBucketsAggregation) {
                for (MultiBucketsAggregation.Bucket bucket : ((MultiBucketsAggregation) result).getBuckets()) {
                    Map<String, Object> map = new HashMap<>();
                    map.put("key_as_string", bucket.getKeyAsString());
                    map.put("key", bucket.getKey());
                    map.put("doc_count", bucket.getDocCount());
                    if (bucket.getAggregations().getAsMap().isEmpty()) {
                        map.put("value", bucket.getDocCount());
                    } else {
                        JsonNode jsonNode = JacksonHelper.transferToJsonNode(bucket.getAggregations().getAsMap());
                        String value = jsonNode.findValue("value") == null ? jsonNode.findValue("docCount").toString() : jsonNode.findValue("value").toString();
                        map.put("value", value);
                    }
                    list.add(map);
                }
            } else {
                Map<String, Object> map = new HashMap<>();
                ParsedCardinality parsedCardinality = (ParsedCardinality) result;
                map.put("name", parsedCardinality.getName());
                map.put("value", parsedCardinality.getValue());
                list.add(map);
            }
            log.info("aggs ok: indexName={}, count={}, timeout={}, tookMillis={}, totalHit={}",
                    indexName, list.size(), response.isTimedOut(), response.getTook().millis(), response.getHits().getTotalHits().value);
            return list;
        } catch (Throwable t) {
            log.warn("aggs failed: indexName={}, error=", indexName, t);
            throw new RuntimeException(String.format("aggs failed: indexName=%s, error=%s", indexName, t));
        }
    }

    /**
     * 流式查询建议加上对_doc字段的排序，否则查询效率非常低
     */
    public void selectByStream(String indexName, FetchSourceContext source, QueryBuilder condition, Map<String, SortOrder> sorts, ResultHandler handler) {
        // 1. 准备查询: 查询字段、查询条件、排序、波次数据量
        SearchSourceBuilder builder = new SearchSourceBuilder();
        builder.fetchSource(source);
        builder.query(condition);
        for (Map.Entry<String, SortOrder> entry : sorts.entrySet()) {
            builder.sort(entry.getKey(), entry.getValue());
        }
        builder.size(handler.waveCount());

        // 2. 执行查询：解析结果
        Scroll scroll = new Scroll(TimeValue.timeValueMinutes(1L));
        SearchRequest request = new SearchRequest(new String[]{indexName}, builder).scroll(scroll);
        if (properties.getRequestCache() != null) {
            request.requestCache(properties.getRequestCache());
        }
        try {
            // 2.1：第一次查询返回第一个波次数据
            SearchResponse response = esClient.search(request, RequestOptions.DEFAULT);
            SearchHit[] hits = response.getHits().getHits();
            String scrollId = response.getScrollId();
            this.handleWaveResult(hits, handler);
            log.info("selectByStream search: indexName={}, tookMillis={}, hitsSize={}, totalHit={}, hitId={}, scrollId={}",
                    indexName, response.getTook().millis(), hits.length, response.getHits().getTotalHits().value, hits.length > 0 ? hits[0].getId() : null, scrollId);

            // 2.2：开始滚动查询后续波次的数据
            while (hits.length > 0 && !handler.stopHandle()) {
                SearchScrollRequest scrollRequest = new SearchScrollRequest(scrollId).scroll(scroll);
                SearchResponse scrollResponse = esClient.scroll(scrollRequest, RequestOptions.DEFAULT);
                hits = scrollResponse.getHits().getHits();
                scrollId = scrollResponse.getScrollId();
                this.handleWaveResult(hits, handler);
                log.info("selectByStream scroll: indexName={}, tookMillis={}, hitsSize={}, totalHit={}, hitId={}, scrollId={}",
                        indexName, scrollResponse.getTook().millis(), hits.length, scrollResponse.getHits().getTotalHits().value, hits.length > 0 ? hits[0].getId() : null, scrollId);
            }

            // 2.3： 请理scroll
            ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
            clearScrollRequest.addScrollId(scrollId);
            ClearScrollResponse clearScrollResponse = esClient.clearScroll(clearScrollRequest, RequestOptions.DEFAULT);
            log.info("selectByStream clearScroll: indexName={}, waveCount={}, stopHandle={}, cleanScroll={}, scrollId={}",
                    indexName, handler.waveCount(), handler.stopHandle(), clearScrollResponse.isSucceeded(), scrollId);
        } catch (Throwable t) {
            log.warn("selectByStream failed: indexName={}, waveCount={}, error=", indexName, handler.waveCount(), t);
            throw new RuntimeException(String.format("selectByStream failed: indexName=%s, waveCount=%s, error=%s", indexName, handler.waveCount(), t));
        }
    }

    private void handleWaveResult(SearchHit[] hits, ResultHandler handler) {
        for (SearchHit hit : hits) {
            Map<String, Object> map = new HashMap<>();
            map.put("_id", hit.getId());
            if (hit.getSourceAsMap() != null) {
                map.putAll(hit.getSourceAsMap());
            }
            handler.handleResult(map);
        }
    }

    public interface ResultHandler {
        void handleResult(Map<String, Object> map);

        boolean stopHandle();

        int waveCount();
    }

}
